{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "from torch.optim import lr_scheduler\n",
    "\n",
    "from model import ECGDataset, ECG_Classifier_LSTM\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from MyEDFImports import load_all_data, load_all_labels, remove_ecg_artifacts, three_stages_transform\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "from tempfile import TemporaryDirectory"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:43:55.615518180Z",
     "start_time": "2023-08-15T15:43:53.946543559Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CN223100.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CP229110.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CX230050.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/DG220020.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/DO223050.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/MyEDFImports.py:41: RuntimeWarning: Channel names are not unique, found duplicates for: {'CHIN EMG'}. Applying running numbers for duplicates.\n",
      "  raw = mne.io.read_raw_edf(path + \"//\" + name)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/LA216100.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/LM230010.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/TK221110.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/VC209100.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/VP214110.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n",
      "Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/WD224010.edf...\n",
      "EDF file detected\n",
      "Setting channel info structure...\n",
      "Creating raw.info structure...\n",
      "NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/MyEDFImports.py:41: RuntimeWarning: Channel names are not unique, found duplicates for: {'CHIN EMG'}. Applying running numbers for duplicates.\n",
      "  raw = mne.io.read_raw_edf(path + \"//\" + name)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<RawEDF | CN223100.edf, 1 x 15611000 (31222.0 s), ~6 kB, data not loaded> with 1561 windows\n",
      "<RawEDF | CP229110.edf, 1 x 20078000 (40156.0 s), ~6 kB, data not loaded> with 2007 windows\n",
      "<RawEDF | CX230050.edf, 1 x 17981000 (35962.0 s), ~6 kB, data not loaded> with 1798 windows\n",
      "<RawEDF | DG220020.edf, 1 x 17756000 (35512.0 s), ~6 kB, data not loaded> with 1775 windows\n",
      "<RawEDF | DO223050.edf, 1 x 18066500 (36133.0 s), ~6 kB, data not loaded> with 1806 windows\n",
      "<RawEDF | LA216100.edf, 1 x 16333500 (32667.0 s), ~6 kB, data not loaded> with 1633 windows\n",
      "<RawEDF | LM230010.edf, 1 x 17246500 (34493.0 s), ~6 kB, data not loaded> with 1724 windows\n",
      "<RawEDF | TK221110.edf, 1 x 15991000 (31982.0 s), ~6 kB, data not loaded> with 1599 windows\n",
      "<RawEDF | VC209100.edf, 1 x 18434500 (36869.0 s), ~6 kB, data not loaded> with 1843 windows\n",
      "<RawEDF | VP214110.edf, 1 x 17252500 (34505.0 s), ~6 kB, data not loaded> with 1725 windows\n",
      "<RawEDF | WD224010.edf, 1 x 17774000 (35548.0 s), ~6 kB, data not loaded> with 1777 windows\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//CN223100.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//CP229110.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//CX230050.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//DG220020.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//DO223050.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//LA216100.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//LM230010.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//TK221110.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//VC209100.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//VP214110.edf_stages.txt\n",
      "loading from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers//WD224010.edf_stages.txt\n",
      "19248\n",
      "15017\n"
     ]
    }
   ],
   "source": [
    "all_unprepared_data = load_all_data()\n",
    "all_unprepared_labels = load_all_labels()\n",
    "\n",
    "print(len(all_unprepared_data))\n",
    "filtered_data, filter_labels = remove_ecg_artifacts(all_unprepared_data, all_unprepared_labels)\n",
    "print(len(filtered_data))\n",
    "# going from 6 labels to three Wake, Nrem, REM\n",
    "filter_labels = three_stages_transform(filter_labels)\n",
    "\n",
    "# this data right now is filtered but not normlized\n",
    "# TODO: normalize"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:44:04.437037807Z",
     "start_time": "2023-08-15T15:43:55.618370221Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "input_size = 64  # Adjust based on the output from conv layers\n",
    "hidden_size = 128\n",
    "num_layers = 2\n",
    "num_classes = 3  # [Wake, NonREM, REM]\n",
    "learning_rate = 0.1\n",
    "batch_size=4"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:44:04.439434366Z",
     "start_time": "2023-08-15T15:44:04.437606799Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "stages = ['train', 'val']\n",
    "dataset_all = ECGDataset(filtered_data, filter_labels)\n",
    "train_data, test_data = random_split(dataset_all, [0.8, 0.2])\n",
    "datasets = {'train': train_data, 'val': test_data}\n",
    "dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in\n",
    "               stages}\n",
    "dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val']}\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = ECG_Classifier_LSTM(input_size, hidden_size, num_layers, num_classes)\n",
    "model = model.to(device)  # Move the model to GPU"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:44:06.123979773Z",
     "start_time": "2023-08-15T15:44:04.442270137Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# inbalanced dataset 4 to 1 so adding weights to criterion\n",
    "crit_weitghts = torch.tensor([4., 1., 4.]).to(device)\n",
    "criterion = nn.CrossEntropyLoss(weight=crit_weitghts)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)\n",
    "model = ECG_Classifier_LSTM(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes, num_layers=num_layers)\n",
    "model = model.to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:44:06.132042144Z",
     "start_time": "2023-08-15T15:44:06.126983785Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n",
    "    since = time.time()\n",
    "    # Create a temporary directory to save training checkpoints\n",
    "    with TemporaryDirectory() as tempdir:\n",
    "        best_model_name = f'best_model_params_{type(criterion).__name__}_{type(optimizer).__name__}_{num_epochs}.pt'\n",
    "        best_model_params_path = os.path.join(tempdir, best_model_name)\n",
    "\n",
    "        torch.save(model.state_dict(), best_model_params_path)\n",
    "        best_acc = 0.0\n",
    "\n",
    "        for epoch in range(num_epochs):\n",
    "            print(f'Epoch {epoch + 1}/{num_epochs}')\n",
    "            print('-' * 10)\n",
    "\n",
    "            for phase in ['train', 'val']:\n",
    "                if phase == 'train':\n",
    "                    print('in training')\n",
    "                    model.train()  # Set model to training mode\n",
    "                else:\n",
    "                    print('in validation')\n",
    "                    model.eval()  # Set model to evaluate mode\n",
    "\n",
    "                running_loss = 0.0\n",
    "                running_corrects = 0\n",
    "            # Iterate over data.\n",
    "            for inputs, labels in dataloaders[phase]:\n",
    "                # No idea why now for inputs why I need to transfer it to a float from a double\n",
    "                inputs = inputs.unsqueeze(1).to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "\n",
    "                    outputs = model(inputs)\n",
    "                    _, preds = torch.max(outputs, 1)\n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        print('updating shit')\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                running_loss += loss.item() * inputs.size(0)\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "            if phase == 'train':\n",
    "                print('updating sched')\n",
    "                scheduler.step()\n",
    "\n",
    "            epoch_loss = running_loss / dataset_sizes[phase]\n",
    "            epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
    "\n",
    "            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')\n",
    "\n",
    "            # deep copy the model\n",
    "            if phase == 'val' and epoch_acc > best_acc:\n",
    "                best_acc = epoch_acc\n",
    "                torch.save(model.state_dict(), best_model_params_path)\n",
    "        time_elapsed = time.time() - since\n",
    "        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')\n",
    "        print(f'Best val Acc: {best_acc:4f}')\n",
    "        model.load_state_dict(torch.load(best_model_params_path))\n",
    "    return model"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:47:36.712447865Z",
     "start_time": "2023-08-15T15:47:36.700226881Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "data": {
      "text/plain": "{'train': <torch.utils.data.dataloader.DataLoader at 0x7fb85f016340>,\n 'val': <torch.utils.data.dataloader.DataLoader at 0x7fb85f016160>}"
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataloaders"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:47:39.463722011Z",
     "start_time": "2023-08-15T15:47:39.459968827Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0960 Acc: 0.6510\n",
      "Epoch 2/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0959 Acc: 0.6510\n",
      "Epoch 3/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Epoch 4/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Epoch 5/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 6/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 7/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 8/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 9/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 10/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 11/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 12/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0962 Acc: 0.6510\n",
      "Epoch 13/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0960 Acc: 0.6510\n",
      "Epoch 14/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0959 Acc: 0.6510\n",
      "Epoch 15/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Epoch 16/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Epoch 17/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0963 Acc: 0.6510\n",
      "Epoch 18/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0963 Acc: 0.6510\n",
      "Epoch 19/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0960 Acc: 0.6510\n",
      "Epoch 20/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0963 Acc: 0.6510\n",
      "Epoch 21/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Epoch 22/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0963 Acc: 0.6510\n",
      "Epoch 23/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0960 Acc: 0.6510\n",
      "Epoch 24/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Epoch 25/25\n",
      "----------\n",
      "in training\n",
      "in validation\n",
      "val Loss: 1.0961 Acc: 0.6510\n",
      "Training complete in 1m 7s\n",
      "Best val Acc: 0.651016\n"
     ]
    }
   ],
   "source": [
    "model = train_model(model, criterion, optimizer, exp_lr_scheduler)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([4, 1, 10000])"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs, labels = next(iter(dataloaders['train']))\n",
    "inputs.unsqueeze(1).shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:44:32.006647629Z",
     "start_time": "2023-08-15T15:44:31.721212231Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "3004"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataloaders['train'])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-15T15:44:56.316492995Z",
     "start_time": "2023-08-15T15:44:56.313984775Z"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
